{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# [23、手写并验证滑动相乘实现PyTorch二维卷积](https://www.bilibili.com/video/BV1dP4y137er)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "卷积示意图：\n",
    "![卷积示意图](http://assets.hypervoid.top/img/2025/07/05/202507051308112-6a79.png)\n",
    "pad=1，步长=2：\n",
    "![pad=1，步长=2](http://assets.hypervoid.top/img/2025/07/05/202507051310881-034c.png)\n",
    "\n",
    "多通道卷积：\n",
    "例如图片有RGB三个通道，就需要 `3*k*k` 大小的卷积核，然后将三个通道加起来作为**一个**out。\n",
    "![](http://assets.hypervoid.top/img/2025/07/05/202507051314911-2593.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "import torch\n",
    "from torch import Tensor\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def my_conv(\n",
    "    input_mat: torch.Tensor,\n",
    "    kernel_mat: Tensor,\n",
    "    bias: Optional[Tensor] = None,\n",
    "    stride: int = 1,\n",
    "    padding: int = 0,\n",
    ") -> Tensor:\n",
    "    \"\"\"\n",
    "    pytorch中 conv2d 输入是4维度卷积，这里是二维\n",
    "\n",
    "    Args:\n",
    "        input_mat (torch.Tensor): 2维矩阵\n",
    "        kernel_mat (Tensor): 2维的卷积核\n",
    "        bias (Optional[Tensor], optional): Defaults to None.\n",
    "        stride (int, optional): 步长. Defaults to 1.\n",
    "        padding (int, optional): 边界周围填充. Defaults to 0.\n",
    "\n",
    "    \"\"\"\n",
    "    if padding != 0:\n",
    "        input_mat = F.pad(input_mat, [padding, padding, padding, padding])\n",
    "    w_in, h_in = input_mat.shape\n",
    "    w_ker, h_ker = kernel_mat.shape\n",
    "\n",
    "    h_out = (h_in - h_ker) // stride + 1\n",
    "    w_out = (w_in - w_ker) // stride + 1\n",
    "    out = torch.zeros(w_out, h_out)\n",
    "    for i in range(w_out):\n",
    "        i_s = i * stride\n",
    "        for j in range(h_out):\n",
    "            j_s = j * stride\n",
    "            submat = input_mat[i_s : i_s + w_ker, j_s : j_s + h_ker]\n",
    "            out[i, j] += (submat * kernel_mat).sum()\n",
    "    if bias:\n",
    "        out += bias\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All Test Passed!\n",
      "All Test Passed!\n",
      "All Test Passed!\n"
     ]
    }
   ],
   "source": [
    "def test_conv():\n",
    "    mat = torch.randn((4, 4))\n",
    "    ker = torch.randn((3, 3))\n",
    "    r1 = my_conv(mat, ker)\n",
    "    r2 = F.conv2d(\n",
    "        mat.unsqueeze(0).unsqueeze(0), ker.unsqueeze(0).unsqueeze(0)\n",
    "    ).squeeze()\n",
    "    assert r1.allclose(r2)\n",
    "    # ----------------\n",
    "    mat = torch.randn((5, 5))\n",
    "    ker = torch.randn((3, 3))\n",
    "    r1 = my_conv(mat, ker, padding=1)\n",
    "    mat.unsqueeze_(0).unsqueeze_(0)\n",
    "    ker.unsqueeze_(0).unsqueeze_(0)\n",
    "    r2 = F.conv2d(mat, ker, padding=1, stride=1).squeeze()\n",
    "    assert r1.allclose(r2)\n",
    "    # ----------------\n",
    "    mat = torch.randn((8, 8))\n",
    "    ker = torch.randn((5, 5))\n",
    "    r1 = my_conv(mat, ker, padding=1,stride=2)\n",
    "    mat.unsqueeze_(0).unsqueeze_(0)\n",
    "    ker.unsqueeze_(0).unsqueeze_(0)\n",
    "    r2 = F.conv2d(mat, ker, padding=1, stride=2).squeeze()\n",
    "    assert r1.allclose(r2)\n",
    "    # ----------------\n",
    "    mat = torch.randn((8, 8))\n",
    "    ker = torch.randn((3, 3))\n",
    "    r1 = my_conv(mat, ker, padding=1,stride=2)\n",
    "    mat.unsqueeze_(0).unsqueeze_(0)\n",
    "    ker.unsqueeze_(0).unsqueeze_(0)\n",
    "    r2 = F.conv2d(mat, ker, padding=1, stride=2).squeeze()\n",
    "    assert r1.allclose(r2)\n",
    "    # ----------------\n",
    "    print(\"All Test Passed!\")\n",
    "    \n",
    "test_conv()\n",
    "test_conv()\n",
    "test_conv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def my_conv2d(\n",
    "    input: Tensor,\n",
    "    kernel: Tensor,\n",
    "    bias: Optional[Tensor] = None,\n",
    "    stride: int = 1,\n",
    "    padding: int = 0,\n",
    ") -> Tensor:\n",
    "    \"\"\"\n",
    "\n",
    "    Args:\n",
    "        input_mat (torch.Tensor): 4维矩阵 (batch_size, input_channels, weight, height)\n",
    "        kernel_mat (Tensor): 2维的卷积核 (out_channels, weight, height)\n",
    "        bias (Optional[Tensor], optional): Defaults to None.\n",
    "        stride (int, optional): 步长. Defaults to 1.\n",
    "        padding (int, optional): 边界周围填充. Defaults to 0.\n",
    "\n",
    "    \"\"\"\n",
    "    assert input.dim() == 4\n",
    "    batch_size, input_channels, w_in, h_in = input.shape\n",
    "    out_channels, w_ker, h_ker = kernel.shape\n",
    "\n",
    "    w_out = (w_in - w_ker) // stride + 1\n",
    "    h_out = (h_in - h_ker) // stride + 1\n",
    "    out = torch.zeros(batch_size, out_channels, w_out, h_out)\n",
    "    for i in range(batch_size):\n",
    "        for oc in range(out_channels):\n",
    "            kernel_i = kernel[oc, :, :]\n",
    "            for ic in range(input_channels):\n",
    "                out[i, oc, :, :] += my_conv(input[i, ic, :, :], kernel_i, bias, stride, padding)\n",
    "\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All Test Passed!\n"
     ]
    }
   ],
   "source": [
    "def test_conv():\n",
    "    mat = torch.randn((3, 1, 4, 4))\n",
    "    ker = torch.randn((2, 3, 3))\n",
    "    r1 = my_conv2d(mat, ker)\n",
    "    r2 = F.conv2d(mat, ker.unsqueeze(1), stride=1)\n",
    "    assert r1.allclose(r2)\n",
    "    print(\"All Test Passed!\")\n",
    "\n",
    "\n",
    "test_conv()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
